import os
import sys
import glob
import h5py
import numpy as np
import json
import cv2
import mindspore.dataset as ds
# Fix the global NumPy RNG so augmentation and train-split shuffling are reproducible.
np.random.seed(1234)

def download_S3DIS():
    """Fetch the S3DIS data into ./data/ next to this file, if not already present.

    Two artifacts are handled:
      * ``indoor3d_sem_seg_hdf5_data`` — downloaded with ``wget`` and unzipped.
        NOTE(review): the shapenet.cs.stanford.edu URL may no longer be live — verify.
      * ``Stanford3dDataset_v1.2_Aligned_Version`` — cannot be downloaded
        automatically (requires a form submission); the user must place the zip
        under data/ manually, otherwise the process exits.

    Side effects: creates directories, shells out to wget/unzip/mv/rm via
    ``os.system`` (paths here are constants, so no untrusted input reaches the
    shell), and may call ``sys.exit(0)``.
    """

    BASE_DIR=os.path.dirname(os.path.abspath(__file__))
    DATA_DIR=os.path.join(BASE_DIR,"data")
    if not os.path.exists(DATA_DIR):
        os.mkdir(DATA_DIR)
    if not os.path.exists(os.path.join(DATA_DIR, 'indoor3d_sem_seg_hdf5_data')):
        www = 'https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip'
        zipfile = os.path.basename(www)
        # Downloads and unzips into the current working directory, then moves
        # the extracted folder under data/.
        os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile))
        os.system('mv %s %s' % ('indoor3d_sem_seg_hdf5_data', DATA_DIR))
        os.system('rm %s' % (zipfile))
    if not os.path.exists(os.path.join(DATA_DIR, 'Stanford3dDataset_v1.2_Aligned_Version')):
        if not os.path.exists(os.path.join(DATA_DIR, 'Stanford3dDataset_v1.2_Aligned_Version.zip')):
            print('Please download Stanford3dDataset_v1.2_Aligned_Version.zip \
                from https://goo.gl/forms/4SoGp4KtH1jfRqEj2 and place it under data/')
            sys.exit(0)
        else:
            zippath = os.path.join(DATA_DIR, 'Stanford3dDataset_v1.2_Aligned_Version.zip')
            # unzip extracts into the current working directory; the result is
            # then moved under data/ and the archive removed.
            os.system('unzip %s' % (zippath))
            os.system('mv %s %s' % ('Stanford3dDataset_v1.2_Aligned_Version', DATA_DIR))
            os.system('rm %s' % (zippath))

def prepare_test_data_semseg():
    """Generate the test-time S3DIS artifacts via the prepare_data scripts.

    Each script is run only when its output directory is missing under
    ./data/ next to this file. Side effects only: shells out with
    ``os.system``; no return value.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(base_dir, 'data')
    # (expected output directory, command that produces it) — in order.
    steps = (
        ('stanford_indoor3d', 'python prepare_data/collect_indoor3d_data.py'),
        ('indoor3d_sem_seg_hdf5_data_test', 'python prepare_data/gen_indoor3d_h5.py'),
    )
    for output_dir, command in steps:
        if not os.path.exists(os.path.join(data_dir, output_dir)):
            os.system(command)

def load_data_semseg(split, test_area):
    """Load S3DIS point-cloud blocks and per-point labels for one split.

    Ensures the data exists (triggering download/preparation if needed), reads
    every HDF5 file listed in ``all_files.txt``, then partitions the blocks by
    room name: rooms containing ``"Area_<test_area>"`` form the test split and
    all remaining rooms form the train split.

    Args:
        split (str): 'train' selects the training HDF5 directory and the
            non-test-area rooms; any other value selects the test directory
            and the test-area rooms.
        test_area (int or str): the area (1-6) held out for testing.

    Returns:
        tuple: ``(all_data, all_seg)`` numpy arrays — the point blocks and
        their per-point segmentation labels for the requested split.
    """
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    DATA_DIR = os.path.join(BASE_DIR, 'data')
    download_S3DIS()
    prepare_test_data_semseg()
    if split == 'train':
        data_dir = os.path.join(DATA_DIR, 'indoor3d_sem_seg_hdf5_data')
    else:
        data_dir = os.path.join(DATA_DIR, 'indoor3d_sem_seg_hdf5_data_test')
    with open(os.path.join(data_dir, "all_files.txt")) as f:
        all_files = [line.rstrip() for line in f]
    with open(os.path.join(data_dir, "room_filelist.txt")) as f:
        room_filelist = [line.rstrip() for line in f]
    data_batchlist, label_batchlist = [], []
    for h5_name in all_files:
        # Open read-only and close promptly: the previous 'r+' mode demanded
        # write permission for a pure read, and the handle was never closed.
        # NOTE(review): entries in all_files.txt are joined onto DATA_DIR —
        # presumably they are paths relative to data/; verify against the list file.
        with h5py.File(os.path.join(DATA_DIR, h5_name), 'r') as h5_file:
            data_batchlist.append(h5_file["data"][:])
            label_batchlist.append(h5_file["label"][:])
    data_batches = np.concatenate(data_batchlist, 0)
    seg_batches = np.concatenate(label_batchlist, 0)
    # room_filelist has one room name per block row; membership of the
    # test-area name decides which split each row belongs to.
    test_area_name = "Area_" + str(test_area)
    train_idxs, test_idxs = [], []
    for i, room_name in enumerate(room_filelist):
        if test_area_name in room_name:
            test_idxs.append(i)
        else:
            train_idxs.append(i)
    idxs = train_idxs if split == 'train' else test_idxs
    all_data = data_batches[idxs, ...]
    all_seg = seg_batches[idxs, ...]
    return all_data, all_seg


def translate_pointcloud(pointcloud):
    """Randomly scale and shift a point cloud (train-time augmentation).

    Each of the three axes gets an independent scale factor drawn from
    U(2/3, 3/2) and an independent offset drawn from U(-0.2, 0.2).

    Args:
        pointcloud (numpy.ndarray): points of shape (N, 3).

    Returns:
        numpy.ndarray: a new float32 array; the input is not modified.
    """
    scale = np.random.uniform(low=2. / 3., high=3. / 2., size=[3])
    shift = np.random.uniform(low=-0.2, high=0.2, size=[3])
    return (pointcloud * scale + shift).astype('float32')

def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
    """Add clipped Gaussian noise to every coordinate, in place.

    Args:
        pointcloud (numpy.ndarray): points of shape (N, C); mutated in place.
        sigma (float): standard deviation of the Gaussian noise.
        clip (float): noise is clamped to the interval [-clip, clip].

    Returns:
        numpy.ndarray: the same (mutated) input array.
    """
    rows, cols = pointcloud.shape
    noise = np.clip(sigma * np.random.randn(rows, cols), -clip, clip)
    pointcloud += noise
    return pointcloud

def rotate_pointcloud(pointcloud):
    """Rotate a point cloud by one random angle in the (x, z) plane, in place.

    A single angle theta ~ U(0, 2*pi) is drawn; columns 0 and 2 are rotated
    together while column 1 (the vertical axis) is left untouched.

    Args:
        pointcloud (numpy.ndarray): points of shape (N, 3); mutated in place.

    Returns:
        numpy.ndarray: the same (mutated) input array.
    """
    angle = np.pi * 2 * np.random.uniform()
    cos_t, sin_t = np.cos(angle), np.sin(angle)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    pointcloud[:, [0, 2]] = pointcloud[:, [0, 2]].dot(rotation)
    return pointcloud

class S3DISDataset:
    """A random-access source dataset of S3DIS blocks for semantic segmentation.

    Wraps the arrays returned by ``load_data_semseg`` and yields
    ``(pointcloud, seg)`` pairs suitable for ``mindspore.dataset.GeneratorDataset``.
    For the train split the point order within each block is randomly permuted
    on every access.

    Args:
        num_points (int): number of points kept per block (the first
            ``num_points`` rows of each stored block are used). Default: 4096.
        split (str): "train" selects all areas except `test_area` (with
            per-access point shuffling); any other value selects only
            `test_area`, without shuffling. Default: "train".
        test_area (int or str): the area (1-6) held out for testing.
            Default: "1".

    Examples:
        >>> dataset_generator = S3DISDataset(4096, "train")
        >>> pointcloud, seg = dataset_generator[0]

    About the S3DIS dataset:
    271 rooms in 6 areas, scanned with a Matterport camera (combined with 3
    structured-light sensors at different spacings); 3D textured meshes,
    RGB-D images and other data are reconstructed, and point clouds are made
    by sampling the meshes. Every point carries one of 13 semantic labels
    (chair, table, floor, wall, etc.).

    Expected on-disk layout (under ./data/ next to this file):

    .. code-block::
    ./S3DIS/
        ├── indoor3d_sem_seg_hdf5_data
        │   ├── ply_data_all_0.h5
        │   ├── ply_data_all_1.h5
        │   ├── ...
        │   └── all_files.txt
        ├── indoor3d_sem_seg_hdf5_data_test
        │   ├── ply_data_all_0.h5
        │   ├── ply_data_all_1.h5
        │   ├── ...
        │   ├── all_files.txt
        │   └── room_filelist.txt
        ├── Stanford3dDataset_v1.2_Aligned_Version
        │   ├── Area_1
        │   ├── Area_2
        │   └── ...
        └── stanford_indoor3d
            ├── Area_1_conferenceRoom_1.npy
            ├── Area_1_conferenceRoom_2.npy
            └── ...
    """
    def __init__(self,num_points=4096,split="train",test_area="1"):
        # load_data_semseg may download/prepare the data on first use.
        self.data,self.seg=load_data_semseg(split,test_area)
        self.num_points=num_points
        self.split=split

    def __getitem__(self, item):
        """Return the ``item``-th block as ``(pointcloud, seg)``.

        Points beyond ``num_points`` are dropped; labels are cast to int32.
        """
        pointcloud=self.data[item][:self.num_points]
        seg=self.seg[item][:self.num_points]
        if self.split=="train":
            # Shuffle point order (and labels in lockstep) as augmentation;
            # uses the module-seeded global NumPy RNG.
            indices=list(range(pointcloud.shape[0]))
            np.random.shuffle(indices)
            pointcloud=pointcloud[indices]
            seg=seg[indices]
        seg=seg.astype(np.int32)
        return pointcloud,seg

    def __len__(self):
        """Return the number of blocks in this split."""
        return self.data.shape[0]

if __name__ == '__main__':

    # Smoke test: build the train split, fetch one sample directly, then wrap
    # the generator in a MindSpore dataset and iterate it in batches of 4.
    dataset_generator = S3DISDataset(4096, 'train')
    data = dataset_generator[0]
    points = data[0]
    label = data[1]
    dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False)
    dataset = dataset.batch(4)
    for data in dataset.create_dict_iterator():
        pointcloud = data['data'].asnumpy()
        label = data['label'].asnumpy()
        print(pointcloud.shape)
        print(label.shape)
    # Number of batches (not samples) after .batch(4).
    print('data size:', dataset.get_dataset_size())